function  output = oplsda(X,Y,nOSC, nPLS, Xtest, Yt,opt)
% output = oplsda(X, Y, nOSC, nPLS, Xtest, Yt, opt)
% Orthogonal Projection to Latent Structure - Discriminant Analysis (OPLS-DA)
%
% Inputs
%   X:     spectra matrix of the training set (samples in rows)
%   Y:     class membership of the training set, either a column vector of
%          class labels or a samples-by-classes indicator matrix
%   nOSC:  the number of orthogonal components ([] selects the default, 1)
%   nPLS:  the number of PLS components ([] selects the default, 1)
%   Xtest: the spectra matrix of the test set (optional, pass [] to skip)
%   Yt:    class membership of the test set, same conventions as Y
%          (optional, pass [] to skip). Label vectors are matched against
%          the training labels; if Y was an indicator matrix, labels in Yt
%          are taken to be column indices of Y.
%   opt:   options structure (required); opt.prepr{1} is the preprocessing
%          method handed to prepfn/unprepfn for both X and Y
%
% output is a structure array with following fields:
% output
%       .TrainingSet: preprocessing parameters of the training set
%                .X_centre: second output of prepfn for X (centre of X)
%                .Y_centre: second output of prepfn for Y (centre of Y)
%       .OSCmodel: the parameters of the OSC filter
%                .W_orth: orthogonal weight
%                .T_orth: orthogonal scores
%                .P_orth: orthogonal loadings
%                .Xcal: the processed training data by OSC filter
%                .R2Xcal: variation explained by Xcal over X
%       .PLSmodel: the parameters of the PLS model
%                .T: PLS X scores
%                .loadings: cell {[P_orth P], Q}; the first cell is
%                   multiplied by sx' when opt.prepr{1} is 'auto'
%                .W: PLS weight matrix (R output of pls2sim)
%                .B: PLS regression coefficient
%                .Yhat_cal: the numeric prediction of training set
%                .C_cal: class membership prediction of training set
%                   (column index of the largest prediction)
%                .ccr_auto: correct classification rate of training set
%       .vip, .TotVip: outputs of myvip2 (TotVip is multiplied by sx' when
%                opt.prepr{1} is 'auto')
%       .R2: 1 - RSS/TSS of the training set predictions
%       .AxesXPlot: axis label strings, one per PLS component
%       .Diagnostics: T-square / Q residual statistics of the training set
%                .tsq, .qres: per-sample T-square and Q residuals
%                .tsqlim, .qlim: limits computed at the 0.95 level
%                .tsqr, .qr: statistics divided by their limits
%
%       .TestSet: the results of test set if Xtest is provided
%                .Xtest: the processed test data by OSC filter
%                .Ypred: the numeric prediction of test set
%                .C_test: class membership prediction of test set
%                .ccr_test: correct classification rate of test set, only
%                given if Yt is also provided
%       .Yhat_val: same as TestSet.Ypred (only if Xtest is provided)
%       .threshold, .YResiduals, .R2_test, .Q2, .RMSEP: only given if both
%                Xtest and Yt are provided
% By Yun Xu, 2016

% opt is dereferenced unconditionally below, so fail early with a clear
% message instead of an "undefined variable" error further down.
if nargin < 7 || ~isfield(opt, 'prepr')
    error('oplsda:missingOptions', ...
        'opt.prepr{1} (the preprocessing method passed to prepfn) must be supplied.')
end

% Because opt is the 7th positional argument, the earlier arguments are
% always present; an empty value selects the documented default.
if isempty(nOSC)
    nOSC = 1;
end
if isempty(nPLS)
    nPLS = 1;
end

[m,n] = size(X);
[mY,c] = size(Y);
if m~=mY
    error('The number of rows in X and Y must be the same!')
end

if c==1
    % Label vector: class_train holds the index of each label within
    % unique_cls, so it is directly comparable with the column indices
    % returned by max() on the predictions (labels need not be 1..c).
    [unique_cls,~,class_train] = unique(Y);
    class_train = class_train(:);
    c = numel(unique_cls);
    Y = zeros(m, c);
    Y(sub2ind([m, c], (1:m)', class_train)) = 1;
else
    % Indicator matrix: classes are identified by their column index.
    unique_cls = (1:c)';
    [~, class_train] = max(Y,[],2);
end

% Preprocessing step for X and Y (both use the same method)
prepMethod = opt.prepr{1};
[X_mc,X_scale,sx]=prepfn(X, prepMethod);    %Preprocessing of X matrix
[Y_mc,Y_scale,sy]=prepfn(Y, prepMethod);    %Preprocessing of Y matrix

output.TrainingSet.X_centre = X_scale;
output.TrainingSet.Y_centre = Y_scale;

% build the OSC filter
hasTest = ~isempty(Xtest);
if ~hasTest
    % NOTE(review): this branch calls oscF while the test-set branch calls
    % osc -- confirm both helpers exist and are meant to differ.
    OSC_model = oscF(X_mc, Y_mc, nOSC);
else
    [mTest,nTest] = size(Xtest);
    if nTest ~= n
        error('The number of columns in Xtest and X must be the same!')
    end
    % NOTE(review): the test data are only centred here; if prepfn also
    % scales X (e.g. 'auto'), Xtest is presumably missing the division by
    % sx -- confirm against prepfn before relying on test predictions.
    Xtest_mc = Xtest - repmat(X_scale, mTest,1);
    OSC_model = osc(X_mc,Y_mc,nOSC, Xtest_mc);
end

% Output OSC model parameters
output.OSCmodel.W_orth=OSC_model.W_orth;
output.OSCmodel.T_orth=OSC_model.T_orth;
output.OSCmodel.P_orth=OSC_model.P_orth;
output.OSCmodel.Xcal=OSC_model.Xcal;
output.OSCmodel.R2Xcal=OSC_model.R2Xcal;

% Build PLS model on OSC corrected data
Xcal=OSC_model.Xcal;
[T,P,~,R,Q,b,~, ssq,~]=pls2sim(Xcal, Y_mc, nPLS);
Yhat_cal=Xcal*b;
Yhat_cal=unprepfn(Yhat_cal,prepMethod,Y_scale,sy);   % back to the original Y scale

% Calculating VIP scores
[vip, viptot]=myvip2(T,[],R,Q,b);
output.vip=vip;
output.TotVip=viptot;

% R2 of the training set
YResiduals=Y-Yhat_cal;
RSS=sum(sum(YResiduals.^2));                         % residual sum of squares
TSS=sum(sum((Y-repmat(mean(Y,1), m, 1)).^2));        % total sum of squares
output.R2=1-RSS/TSS;

% Predicted class = column with the largest predicted value
[~, C_cal] = max(Yhat_cal,[],2);
ccr_auto = sum(C_cal == class_train)/numel(class_train);
output.PLSmodel.T=T;
output.PLSmodel.loadings={[OSC_model.P_orth P] Q};
output.PLSmodel.W=R;
output.PLSmodel.B=b;
output.PLSmodel.Yhat_cal = Yhat_cal;
output.PLSmodel.C_cal = C_cal;
output.PLSmodel.ccr_auto = ccr_auto;

% Backscale the Loadings and VIP score when using Autoscale
switch prepMethod
    case 'auto'
        output.TotVip=output.TotVip.*sx';
        output.PLSmodel.loadings{1,1}=output.PLSmodel.loadings{1,1}.*sx';
end

% Explained variance: one axis label per PLS component
AxesXPlot = cell(nPLS,1);
for i=1:nPLS
    AxesXPlot{i,1} = sprintf('Scores on PC %d (%4g%%)',i,round((ssq{1, 1}(i,1)),2));
end
output.AxesXPlot= AxesXPlot;

% Correcting new samples if given
if hasTest
    output.TestSet.Xtest=OSC_model.Xtest;
    Yhat = OSC_model.Xtest*b;
    Ypred=unprepfn(Yhat,prepMethod,Y_scale,sy);
    output.Yhat_val = Ypred;
    [~, C_test] = max(Ypred,[],2);
    output.TestSet.Ypred = Ypred;
    output.TestSet.C_test = C_test;
end

% T-square and Q residual diagnostics of the training set
facts=min([nPLS n m-1]);
q=sum((Xcal-T*P').^2,2);
% Row-wise form of diag((m-1)*T/(T'*T)*T'); avoids building an m-by-m matrix.
t2=sum((((m-1)*T)/(T'*T)).*T,2);
t2lim=facts*(m-1)*finv(0.95,facts,m-facts)./(m-facts);
sres=sum(q/(m-1));
% Computation of qlim using J-M approx
theta1 = sum(sres);
theta2 = sum(sres.^2);
theta3 = sum(sres.^3);
if theta1==0
    qlim = 0;
else
    h0 = 1-2*theta1*theta3/3/(theta2.^2);
    if h0<0.001
        h0 = 0.001;
    end
    ca   = sqrt(2)*erfinv(2*0.95-1);
    h1   = ca*sqrt(2*theta2*h0.^2)/theta1;
    h2   = theta2*h0*(h0-1)/(theta1.^2);
    qlim = theta1*(1+h1+h2).^(1/h0);
end

model.tsq=t2;
model.qres=q;
model.tsqlim=t2lim;
model.qlim=qlim;
model.tsqr=t2/t2lim;
model.qr=q/qlim;
% Previously computed and then discarded; now returned to the caller.
output.Diagnostics = model;

% Test-set evaluation needs both the predictions and the known classes
if hasTest && ~isempty(Yt)
    [nT,cT]=size(Yt);
    if nT ~= size(Ypred,1)
        error('The number of rows in Xtest and Yt must be the same!')
    end
    if cT==1
        % Map test labels onto the TRAINING classes so that the indicator
        % matrix lines up column-for-column with Ypred, even when the test
        % set does not contain every class.
        [known, class_test] = ismember(Yt, unique_cls);
        if ~all(known)
            error('Yt contains class labels that are not present in the training set!')
        end
        class_test = class_test(:);
        Yt = zeros(nT, c);
        Yt(sub2ind([nT, c], (1:nT)', class_test)) = 1;
    else
        if cT ~= c
            error('The number of columns in Yt must match the number of classes in Y!')
        end
        [~, class_test]=max(Yt,[],2);
    end

    % Checking the threshold
    output.threshold = plsdafindthr(Ypred,class_test);

    ccr_test = sum(C_test==class_test)/numel(class_test);

    output.YResiduals=Yt-Ypred;

    RSS=sum(sum(output.YResiduals.^2));
    TSS=sum(sum((Yt-repmat(mean(Yt,1), nT, 1)).^2));

    output.R2_test=1-RSS/TSS;
    % PRESS is the same residual sum of squares as RSS here, so Q2 equals
    % R2_test; both fields are kept for callers that read either one.
    output.Q2=1-RSS/TSS;

    output.RMSEP=sqrt(sum(output.YResiduals.^2,1)/nT);
    output.TestSet.ccr_test = ccr_test;
end
end

